package com.xiaoyu.tio.redis.core.handler;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tio.client.intf.ClientAioHandler;
import org.tio.core.ChannelContext;
import org.tio.core.GroupContext;
import org.tio.core.exception.AioDecodeException;
import org.tio.core.intf.Packet;
import org.tio.core.utils.ByteBufferUtils;

import com.alibaba.fastjson.JSON;
import com.xiaoyu.tio.redis.core.MessageResponse;
import com.xiaoyu.tio.redis.core.RedisPacket;

/**
 * t-io client handler speaking the Redis serialization protocol (RESP).
 *
 * <p>{@link #encode} serializes a {@link RedisPacket} command as a RESP array of bulk strings and
 * registers the packet's future. {@link #decode} frames exactly one complete RESP reply out of the
 * socket buffer. {@link #handler} parses that reply and completes the oldest pending future.
 *
 * <p>Replies are paired with requests purely by FIFO order: the id queued by the N-th
 * {@code encode} call receives the N-th handled reply. NOTE(review): this assumes {@code encode}
 * runs in the same order the bytes reach the socket — confirm against t-io's send path.
 */
public class RedisHandler implements ClientAioHandler {

	private final static Logger logger = LoggerFactory.getLogger(RedisHandler.class);

	private static final String CRLF = "\r\n";

	/** Carriage return, first byte of every RESP line terminator. */
	private static final byte CR_BYTE = 13;

	/** Line feed, second byte of every RESP line terminator. */
	private static final byte LF_BYTE = 10;

	private static final String MULTI_BULK_REPLY = "*";

	/** {@code '*'}: array (multi-bulk) type marker. */
	private static final byte MULTI_BULK_BYTE = 42;

	/** {@code '$'}: bulk string type marker. */
	private static final byte BULK_BYTE = 36;

	/** {@code ':'}: integer type marker. */
	private static final byte INT_BYTE = 58;

	/** {@code '+'}: simple string (status) type marker. */
	private static final byte STATUS_BYTE = 43;

	/** {@code '-'}: error type marker. */
	private static final byte ERROR_BYTE = 45;

	/** Size of the shortest RESP value, an empty simple string {@code "+\r\n"}. */
	private static final int MIN_LENGTH = 3;

	/** Returned by {@link #parseValue} when the buffer ends before the value does. */
	private static final int INCOMPLETE = -1;

	/** Futures of in-flight requests, keyed by packet id. */
	private final ConcurrentHashMap<Long, CompletableFuture<MessageResponse>> data = new ConcurrentHashMap<>();

	/** Ids of in-flight requests in encode order (a FIFO queue despite the name). */
	private final ConcurrentLinkedQueue<Long> stack = new ConcurrentLinkedQueue<>();

	/**
	 * Frames one complete RESP reply from the readable region of {@code buffer}.
	 *
	 * <p>Nested arrays are framed recursively. When a full reply is present, exactly its bytes are
	 * read from {@code buffer} and stored as the packet's {@code res}. When the reply is still
	 * incomplete, nothing is read and {@code null} is returned.
	 *
	 * @throws AioDecodeException if a type marker is unknown or a length/count/terminator is malformed
	 */
	@Override
	public Packet decode(ByteBuffer buffer, int limit, int position, int readableLength, ChannelContext channelContext)
			throws AioDecodeException {
		if (readableLength < MIN_LENGTH) {
			return null;
		}
		byte[] dst = new byte[readableLength];
		// arrayOffset() keeps the copy correct when the buffer is a slice of a larger array.
		System.arraycopy(buffer.array(), buffer.arrayOffset() + position, dst, 0, readableLength);
		int end = parseValue(dst, 0, null);
		if (end == INCOMPLETE) {
			return null;
		}
		byte[] need = new byte[end];
		// Reads from buffer.position(); assumes it equals the `position` argument (as the original did).
		buffer.get(need);
		RedisPacket rp = new RedisPacket();
		rp.setRes(need);
		return rp;
	}

	/**
	 * Parses one RESP value whose type marker is at {@code start}.
	 *
	 * @param buf   bytes available so far
	 * @param start index of the value's type marker
	 * @param out   if non-null, receives payloads in order:
	 *              <ul>
	 *              <li>one entry per simple string, error, integer or bulk string (without the type
	 *              marker and CRLF);</li>
	 *              <li>{@code null} for a null bulk string;</li>
	 *              <li>array elements appended in order, with nested arrays flattened;</li>
	 *              <li>nothing for a null or empty array.</li>
	 *              </ul>
	 * @return index just past the value, or {@link #INCOMPLETE} if {@code buf} ends first
	 * @throws AioDecodeException on an unknown type marker or a malformed length/terminator
	 */
	private static int parseValue(byte[] buf, int start, List<byte[]> out) throws AioDecodeException {
		if (start >= buf.length) {
			return INCOMPLETE;
		}
		byte type = buf[start];
		if (type != MULTI_BULK_BYTE && type != BULK_BYTE && type != INT_BYTE && type != STATUS_BYTE
				&& type != ERROR_BYTE) {
			throw new AioDecodeException("未知redis命令  !!");
		}
		Integer lfIndex = nextCRLFIndex(buf, start);
		if (lfIndex == null) {
			return INCOMPLETE;
		}
		int crIndex = lfIndex - 1;
		int next = lfIndex + 1;
		switch (type) {
		case MULTI_BULK_BYTE: {
			// A count <= 0 (empty or null array) has no elements.
			int count = parseLength(buf, start + 1, crIndex);
			for (int k = 0; k < count && next != INCOMPLETE; k++) {
				next = parseValue(buf, next, out);
			}
			return next;
		}
		case BULK_BYTE: {
			int length = parseLength(buf, start + 1, crIndex);
			if (length < 0) {
				// Null bulk string ("$-1\r\n") has no payload line.
				if (out != null) {
					out.add(null);
				}
				return next;
			}
			// long arithmetic so a huge declared length cannot overflow into a bogus index.
			long end = (long) next + length + 2;
			if (end > buf.length) {
				return INCOMPLETE;
			}
			int payloadEnd = next + length;
			if (buf[payloadEnd] != CR_BYTE || buf[payloadEnd + 1] != LF_BYTE) {
				throw new AioDecodeException("协议格式错误！");
			}
			if (out != null) {
				out.add(Arrays.copyOfRange(buf, next, payloadEnd));
			}
			return payloadEnd + 2;
		}
		default:
			// Integer, status and error replies: the payload is the rest of the line.
			if (out != null) {
				out.add(Arrays.copyOfRange(buf, start + 1, crIndex));
			}
			return next;
		}
	}

	/**
	 * Parses the ASCII decimal integer in {@code buf[from, to)} (a bulk length or array count).
	 *
	 * @throws AioDecodeException if the bytes are not a valid {@code int}
	 */
	private static int parseLength(byte[] buf, int from, int to) throws AioDecodeException {
		try {
			return Integer.parseInt(new String(buf, from, to - from, StandardCharsets.US_ASCII));
		} catch (NumberFormatException e) {
			AioDecodeException ex = new AioDecodeException("协议格式错误！");
			ex.initCause(e);
			throw ex;
		}
	}

	/**
	 * Searches {@code dst} from {@code startIndex} for the next CR LF terminator.
	 *
	 * @param dst        bytes to search
	 * @param startIndex first index to examine
	 * @return index of the LF byte of the first CR LF pair, or {@code null} if none lies within
	 *         {@code dst}
	 */
	private static Integer nextCRLFIndex(byte[] dst, int startIndex) {
		if (startIndex + 1 >= dst.length) {
			return null;
		}
		for (int z = startIndex; z < dst.length; z++) {
			int lfIndex = z + 1;
			if (lfIndex >= dst.length) {
				return null;
			}
			if (dst[z] == CR_BYTE && dst[lfIndex] == LF_BYTE) {
				return lfIndex;
			}
		}
		return null;
	}

	/**
	 * Serializes the packet's command and parameters as a RESP array of bulk strings.
	 *
	 * <p>The packet's future is registered under its id and the id is queued so that
	 * {@link #handler} can complete it.
	 *
	 * @return a buffer positioned at 0 whose limit is the encoded length
	 */
	@Override
	public ByteBuffer encode(Packet packet, GroupContext groupContext, ChannelContext channelContext) {
		RedisPacket rp = (RedisPacket) packet;
		List<byte[]> parms = rp.getParms();
		// The command name is sent as the first bulk string, followed by each parameter.
		List<byte[]> args = new ArrayList<>(parms.size() + 1);
		args.add(rp.getCmd().getBytes(StandardCharsets.UTF_8));
		args.addAll(parms);

		byte[] header = (MULTI_BULK_REPLY + args.size() + CRLF).getBytes(StandardCharsets.US_ASCII);
		byte[][] lengthDigits = new byte[args.size()][];
		int size = header.length;
		for (int i = 0; i < args.size(); i++) {
			lengthDigits[i] = String.valueOf(args.get(i).length).getBytes(StandardCharsets.US_ASCII);
			// '$' + length digits + CRLF + payload + CRLF; lengths are byte counts, not char counts.
			size += 1 + lengthDigits[i].length + 2 + args.get(i).length + 2;
		}

		// One exact-size allocation instead of re-copying the buffer for every argument.
		ByteBuffer fullBuffer = ByteBuffer.allocate(size);
		fullBuffer.put(header);
		for (int i = 0; i < args.size(); i++) {
			fullBuffer.put(BULK_BYTE).put(lengthDigits[i]).put(CR_BYTE).put(LF_BYTE);
			fullBuffer.put(args.get(i)).put(CR_BYTE).put(LF_BYTE);
		}
		fullBuffer.flip();

		// Register the future before publishing the id, so handler() never sees an id without one.
		addMessage(rp.getId(), rp.getFuture());
		stackAdd(rp.getId());
		fullBuffer.order(groupContext.getByteOrder());
		return fullBuffer;
	}

	/**
	 * Completes the oldest pending request's future with the payloads of a decoded reply.
	 *
	 * <p>The pending id is dequeued before the reply is parsed. A malformed reply therefore fails
	 * only its own request: its future is completed exceptionally and the error is rethrown. Later
	 * replies stay paired with the correct requests.
	 */
	@Override
	public void handler(Packet packet, ChannelContext channelContext) throws Exception {
		RedisPacket redisPacket = (RedisPacket) packet;
		byte[] body = redisPacket.getRes();
		if (body == null) {
			// body is null here: pass it as an argument instead of dereferencing it.
			logger.error("数据格式错误:{}", body);
			return;
		}
		Long id = stack.poll();
		CompletableFuture<MessageResponse> future = id == null ? null : this.data.remove(id);
		List<byte[]> reply;
		try {
			reply = replyData(body);
		} catch (Exception e) {
			if (future != null) {
				future.completeExceptionally(e);
			}
			throw e;
		}
		if (future == null) {
			logger.warn("未找到消息ID:{} 数据:{}", id, JSON.toJSONString(reply));
			return;
		}
		MessageResponse messageResponse = new MessageResponse();
		messageResponse.setData(reply);
		messageResponse.setId(id);
		future.complete(messageResponse);
	}

	/** No heartbeat is sent; returns {@code null}. */
	@Override
	public Packet heartbeatPacket() {
		return null;
	}

	/**
	 * Scratch entry point; performs no work. Example replies this handler decodes:
	 * {@code *2 :1 $3 1as} (array of integer and bulk), {@code :1123}, {@code -ERR}, {@code +OK},
	 * each line terminated by CR LF.
	 */
	public static void main(String[] args) {
	}

	/**
	 * Extracts the payloads of one complete RESP reply, as framed by {@link #decode}.
	 *
	 * <p>Payload rules:
	 * <ul>
	 * <li>A scalar reply yields one entry.</li>
	 * <li>A null bulk string yields a {@code null} entry.</li>
	 * <li>Array elements are listed in order, with nested arrays flattened.</li>
	 * <li>A null or empty array yields an empty list.</li>
	 * </ul>
	 *
	 * @throws AioDecodeException if {@code body} is not one complete, well-formed RESP value
	 */
	private List<byte[]> replyData(byte[] body) throws AioDecodeException {
		List<byte[]> data = new ArrayList<>();
		if (parseValue(body, 0, data) == INCOMPLETE) {
			throw new AioDecodeException("协议格式错误！");
		}
		return data;
	}

	/** Queues {@code id} as the next request awaiting a reply. */
	private void stackAdd(Long id) {
		stack.add(id);
	}

	/** Registers the future to complete when the reply for {@code id} arrives. */
	private void addMessage(Long id, CompletableFuture<MessageResponse> messageResponse) {
		data.put(id, messageResponse);
	}

}
